{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "# Build a digit classifier app with TensorFlow Lite\n",
        "\n",
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/examples/blob/master/lite/examples/digit_classifier/ml/mnist_tflite.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/examples/blob/master/lite/examples/digit_classifier/ml/mnist_tflite.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xHxb-dlhMIzW"
      },
      "source": [
        "## Overview\n",
        "\n",
        "This notebook shows an end-to-end example of training a TensorFlow model using Keras and Python, then export it to TensorFlow Lite format to use in mobile apps. Here we will train a handwritten digit classifier using MNIST dataset.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MUXex9ctTuDB"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kS_mq4yAlXHZ"
      },
      "outputs": [],
      "source": [
        "# TensorFlow and tf.keras\n",
        "import tensorflow as tf\n",
        "from tensorflow import keras\n",
        "\n",
        "# Helper libraries\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import math\n",
        "\n",
        "print(tf.__version__)\n",
        "\n",
        "# Helper function to display digit images\n",
        "def show_sample(images, labels, sample_count=25):\n",
        "  # Create a square with can fit {sample_count} images\n",
        "  grid_count = math.ceil(math.ceil(math.sqrt(sample_count)))\n",
        "  grid_count = min(grid_count, len(images), len(labels))\n",
        "  \n",
        "  plt.figure(figsize=(2*grid_count, 2*grid_count))\n",
        "  for i in range(sample_count):\n",
        "    plt.subplot(grid_count, grid_count, i+1)\n",
        "    plt.xticks([])\n",
        "    plt.yticks([])\n",
        "    plt.grid(False)\n",
        "    plt.imshow(images[i], cmap=plt.cm.gray)\n",
        "    plt.xlabel(labels[i])\n",
        "  plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r0WlLrJcnwWp"
      },
      "source": [
        "## Download and explore the MNIST dataset\n",
        "The MNIST database contains 60,000 training images and 10,000 testing images of handwritten digits. We will use the dataset to demonstrate how to train a image classification model and convert it to TensorFlow Lite format.\n",
        "\n",
        "Each image in the MNIST dataset is a 28x28 grayscale image containing a digit.\n",
        "![MNIST sample](https://storage.googleapis.com/khanhlvg-public.appspot.com/digit-classifier/mnist.png)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B5REuMrblewK"
      },
      "outputs": [],
      "source": [
        "# Download MNIST dataset.\n",
        "mnist = keras.datasets.mnist\n",
        "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
        "\n",
        "# If you can't download the MNIST dataset from Keras, please try again with an alternative method below\n",
        "# path = keras.utils.get_file('mnist.npz',\n",
        "#                             origin='https://s3.amazonaws.com/img-datasets/mnist.npz',\n",
        "#                             file_hash='8a61469f7ea1b51cbae51d4f78837e45')\n",
        "# with np.load(path, allow_pickle=True) as f:\n",
        "#   train_images, train_labels = f['x_train'], f['y_train']\n",
        "#   test_images, test_labels = f['x_test'], f['y_test']"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "REY5lDDylpFh"
      },
      "outputs": [],
      "source": [
        "# Normalize the input image so that each pixel value is between 0 to 1.\n",
        "train_images = train_images / 255.0\n",
        "test_images = test_images / 255.0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SAOE84IplyWR"
      },
      "outputs": [],
      "source": [
        "# Show the first 25 images in the training dataset.\n",
        "show_sample(train_images, \n",
        "            ['Label: %s' % label for label in train_labels])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9v-Wp3TxpLX6"
      },
      "source": [
        "## Train a TensorFlow model to classify digit images\n",
        "We use Keras API to build a TensorFlow model that can classify the digit images. Please see this [tutorial](https://www.tensorflow.org/beta/tutorials/keras/basic_classification) if you are interested to learn more about how to build machine learning model with Keras and TensorFlow."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IWgBGmaplzcp"
      },
      "outputs": [],
      "source": [
        "# Define the model architecture\n",
        "model = keras.Sequential([\n",
        "    keras.layers.Flatten(input_shape=(28, 28)),\n",
        "    keras.layers.Dense(128, activation=tf.nn.relu),\n",
        "\n",
        "# Optional: You can replace the dense layer above with the convolution layers below to get higher accuracy.\n",
        "    # keras.layers.Reshape(target_shape=(28, 28, 1)),\n",
        "    # keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation=tf.nn.relu),\n",
        "    # keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation=tf.nn.relu),\n",
        "    # keras.layers.MaxPooling2D(pool_size=(2, 2)),\n",
        "    # keras.layers.Dropout(0.25),\n",
        "    # keras.layers.Flatten(input_shape=(28, 28)),\n",
        "    # keras.layers.Dense(128, activation=tf.nn.relu),\n",
        "    # keras.layers.Dropout(0.5),\n",
        "\n",
        "    keras.layers.Dense(10)\n",
        "])\n",
        "\n",
        "model.compile(optimizer='adam',\n",
        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "              metrics=['accuracy'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V6SOZuLRmEzS"
      },
      "outputs": [],
      "source": [
        "# Train the digit classification model\n",
        "model.fit(train_images, train_labels, epochs=5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "za35DFJ5uFkX"
      },
      "source": [
        "## Evaluate our model\n",
        "We run our digit classification model against our test dataset that the model hasn't seen during its training process. We want to confirm that the model didn't just remember the digits it saw but also generalize well to new images."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sJI8nqFWmMwC"
      },
      "outputs": [],
      "source": [
        "# Evaluate the model using test dataset.\n",
        "test_loss, test_acc = model.evaluate(test_images, test_labels)\n",
        "\n",
        "print('Test accuracy:', test_acc)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PdlXEyUImeXC"
      },
      "outputs": [],
      "source": [
        "# Predict the labels of digit images in our test dataset.\n",
        "predictions = model.predict(test_images)\n",
        "\n",
        "# Then plot the first 25 test images and their predicted labels.\n",
        "show_sample(test_images, \n",
        "            ['Predicted: %d' % np.argmax(result) for result in predictions])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AWROBI4iv9fY"
      },
      "source": [
        "## Convert the Keras model to TensorFlow Lite"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2fXStjR4mzkR"
      },
      "outputs": [],
      "source": [
        "# Convert Keras model to TF Lite format.\n",
        "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
        "tflite_model = converter.convert()\n",
        "\n",
        "# Save the TF Lite model as file\n",
        "f = open('mnist.tflite', \"wb\")\n",
        "f.write(tflite_model)\n",
        "f.close()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Q_Z5yLxrwbpI"
      },
      "outputs": [],
      "source": [
        "# Download the digit classification model if you're using Colab, \n",
        "# or print the model's local path if you're not using Colab.\n",
        "try:\n",
        "  from google.colab import files\n",
        "  files.download('mnist.tflite')\n",
        "except ImportError:\n",
        "  import os\n",
        "  print('TF Lite model:', os.path.join(os.getcwd(), 'mnist.tflite'))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3TvDxaYU2ui7"
      },
      "source": [
        "## Verify the TensorFlow Lite model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fFtIESwrx7cR"
      },
      "outputs": [],
      "source": [
        "# Download a test image\n",
        "zero_img_path = keras.utils.get_file(\n",
        "    'zero.png', \n",
        "    'https://storage.googleapis.com/khanhlvg-public.appspot.com/digit-classifier/zero.png'\n",
        ")\n",
        "image = keras.preprocessing.image.load_img(\n",
        "    zero_img_path,\n",
        "    color_mode = 'grayscale',\n",
        "    target_size=(28, 28),\n",
        "    interpolation='bilinear'\n",
        ")\n",
        "\n",
        "# Pre-process the image: Adding batch dimension and normalize the pixel value to [0..1]\n",
        "# In training, we feed images in a batch to the model to improve training speed, making the model input shape to be (BATCH_SIZE, 28, 28).\n",
        "# For inference, we still need to match the input shape with training, so we expand the input dimensions to (1, 28, 28) using np.expand_dims\n",
        "input_image = np.expand_dims(np.array(image, dtype=np.float32) / 255.0, 0)\n",
        "\n",
        "# Show the pre-processed input image\n",
        "show_sample(input_image, ['Input Image'], 1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xPtbtEJ2uacB"
      },
      "outputs": [],
      "source": [
        "# Run inference with TensorFlow Lite\n",
        "interpreter = tf.lite.Interpreter(model_content=tflite_model)\n",
        "interpreter.allocate_tensors()\n",
        "interpreter.set_tensor(interpreter.get_input_details()[0][\"index\"], input_image)\n",
        "interpreter.invoke()\n",
        "output = interpreter.tensor(interpreter.get_output_details()[0][\"index\"])()[0]\n",
        "\n",
        "# Print the model's classification result\n",
        "digit = np.argmax(output)\n",
        "print('Predicted Digit: %d\\nConfidence: %f' % (digit, output[digit]))"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "mnist_tflite.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
